# code to estimate hierarchical LASSO
# DV: network ideal point
# IVs: average document scores estimated using pivot scaling

library(tidyverse)
library(data.table)
library(dtplyr)
library(network)

library(hierNet)

require(foreach)
require(doMC)
# Register the doMC backend used by foreach's %dopar%, with 10 worker cores.
registerDoMC(cores = 10)

# Load project helper functions (contents not shown in this file).
source("pundits_functions.R")

# Fix the RNG state for the script; run_predideo() re-seeds with its own `seed`.
set.seed(1111)

# Modified copy of hierNet::hierNet.cv that additionally keeps the out-of-fold
# predictions (not only the cross-validation error).
#
# Arguments:
#   fit     - object of class "hierNet.path". Its `type` ("gaussian" or
#             "logistic") selects the loss and the path-fitting function, and
#             the tuning settings stored on it are reused for every fold.
#   x, y    - predictor matrix and outcome vector; rows of `x` match `y`.
#   nfolds  - number of folds to draw when `folds` is NULL.
#   folds   - optional list of held-out row-index vectors, one per fold.
#   trace   - forwarded to the path-fitting function.
#   lamlist - penalty values to evaluate; defaults to fit$lamlist.
#   inpar   - if TRUE the folds are run through foreach/%dopar% (a parallel
#             backend must already be registered); otherwise sequentially.
#
# Returns a list of class "hierNet.cv" with elements lamlist, cv.err, cv.se,
# lamhat, lamhat.1se, nonzero, folds, call and `yhat`. `yhat` is a data frame
# with the held-out row index, the fold id and the predictions for each value
# in `lamlist`, in the same layout for the parallel and sequential paths.
hierNet.cv.yhat <- 
  function(fit, x, y, nfolds=10, folds=NULL, trace=0, lamlist = NULL, inpar = FALSE) {
  this.call <- match.call()
  stopifnot(inherits(fit, "hierNet.path"))

  # Per-observation loss: squared error (gaussian) or misclassification (logistic).
  errfun <- switch(fit$type,
                   gaussian = function(y, yhat) (y - yhat)^2,
                   logistic = function(y, yhat) 1 * (y != yhat),
                   stop("unsupported fit type: ", fit$type))
  is_gaussian <- fit$type == "gaussian"

  n <- length(y)
  if (is.null(folds)) {
    folds <- split(sample(1:n), rep(1:nfolds, length.out = n))
  } else {
    stopifnot(is.list(folds))
  }
  # split() yields fewer than `nfolds` groups when n < nfolds, so always take
  # the fold count from the list that is actually iterated over.
  nfolds <- length(folds)

  if (is.null(lamlist)) {
    lamlist <- fit$lamlist
  }

  # get whether fit was standardized based on fit$sx and fit$szz...
  if (is.null(fit$mx)) stop("hierNet object was not centered.  hierNet.cv has not been written for this (unusual) case.")
  stand.main <- !is.null(fit$sx)
  stand.int <- !is.null(fit$szz)

  n.lamlist <- length(lamlist)

  # Refit the path without fold `ii`, predict the held-out rows and return both
  # the predictions and the fold's mean error for each value in `lamlist`.
  cv_one_fold <- function(ii) {
    held_out <- folds[[ii]]
    # drop = FALSE keeps matrices even when a fold holds a single observation.
    x_train <- x[-held_out, , drop = FALSE]
    x_test <- x[held_out, , drop = FALSE]

    path_fun <- if (is_gaussian) hierNet.path else hierNet.logistic.path
    a <- path_fun(x_train, y = y[-held_out],
                  lamlist = lamlist, delta = fit$delta, diagonal = fit$diagonal,
                  strong = fit$strong, trace = trace,
                  stand.main = stand.main, stand.int = stand.int,
                  # ADMM parameters (which will be NULL if strong=F)
                  rho = fit$rho, niter = fit$niter, sym.eps = fit$sym.eps,
                  # GG descent params
                  step = fit$step, maxiter = fit$maxiter,
                  backtrack = fit$backtrack, tol = fit$tol)

    # The logistic predict method returns a list; the labels live in $yhat.
    yhatt <- if (is_gaussian) {
      predict(a, newx = x_test)
    } else {
      predict(a, newx = x_test)$yhat
    }

    hatdf <- data.frame(index = held_out,
                        fold = ii,
                        yhat = yhatt)

    temp <- matrix(y[held_out], nrow = length(held_out), ncol = n.lamlist)
    list(yhats = hatdf, err2 = colMeans(errfun(yhatt, temp)))
  }

  # Stack two fold results: predictions row-wise, errors as one row per fold.
  combine_custom <- function(LL1, LL2) {
    list(yhats = rbind(LL1$yhats, LL2$yhats),
         err2 = rbind(LL1$err2, LL2$err2))
  }

  if (inpar) {
    message("cross-validating in parallel")
    cv_res <- foreach(ii = seq_len(nfolds),
                      .combine = combine_custom,
                      .inorder = TRUE) %dopar% {
      cv_one_fold(ii)
    }
  } else {
    fold_results <- lapply(seq_len(nfolds), function(ii) {
      cat("Fold", ii, ":")
      fold_out <- cv_one_fold(ii)
      cat("\n")
      fold_out
    })
    cv_res <- Reduce(combine_custom, fold_results)
  }

  yhats <- cv_res$yhats
  err2 <- cv_res$err2
  # With a single fold nothing was rbind-ed, so err2 is still a plain vector.
  if (!is.matrix(err2)) {
    err2 <- matrix(err2, nrow = 1, dimnames = list(NULL, names(err2)))
  }

  errm <- colMeans(err2)
  errse <- sqrt(apply(err2, 2, var) / nfolds)
  o <- which.min(errm)
  lamhat <- lamlist[o]
  # One-standard-error rule: first lambda at least as large as lamhat whose
  # error is within one SE of the minimum.
  oo <- errm <= errm[o] + errse[o]
  lamhat.1se <- lamlist[oo & lamlist >= lamhat][1]

  # Count of nonzero main effects plus nonzero interactions along the path of
  # the full-data fit.
  nonzero <- colSums(fit$bp - fit$bn != 0) +
    apply(fit$th != 0, 3, function(a) sum(diag(a)) + sum((a + t(a) != 0)[upper.tri(a)]))

  obj <- list(lamlist = lamlist, cv.err = errm, cv.se = errse,
              lamhat = lamhat, lamhat.1se = lamhat.1se,
              nonzero = nonzero, folds = folds,
              yhat = yhats,
              call = this.call)
  class(obj) <- "hierNet.cv"
  obj
}

# Routinized modeling procedure: build the scaled design matrix, fit the
# hierNet path if no fitted model is supplied, cross-validate it in parallel
# with hierNet.cv.yhat(), and save the results under `save_path`.
#
# Arguments:
#   input       - data frame holding the predictor columns.
#   outcome     - outcome vector aligned with the rows of `input`.
#   includevars - character vector of column names in `input` to use.
#   mod         - optional previously fitted hierNet.path object; when NULL a
#                 new path is fitted and saved as "mod_<name>.RData".
#   lamlist     - kept for backward compatibility with existing callers; it is
#                 not used. Cross-validation always runs on mod$lamlist.
#   save_path   - directory prefix (pasted as-is) for the output files.
#   name        - label used in the output file names.
#   seed        - RNG seed set at the start of the run.
#
# Side effects: writes "modcv_yhats_<name>.RData" (and possibly
# "mod_<name>.RData") to `save_path`.
# Returns the cross-validation object invisibly.
run_predideo <- function(input, 
                         outcome, 
                         includevars,
                         mod = NULL,
                         lamlist = NULL,
                         save_path = "../output/",
                         name,
                         seed = 11111){
  # library() rather than require(): fail immediately if hierNet is missing.
  library(hierNet)
  set.seed(seed)
  
  message('shaping input')
  X <- input %>%
    dplyr::select(all_of(includevars)) %>%
    as.matrix()
  
  X_scale <- as.matrix(data.frame(scale(X)))
  
  Y <- outcome
  # Compute the missingness mask once, before Y is modified, and apply the
  # same mask to both Y and X_scale so the rows stay aligned.
  missing_y <- is.na(Y)
  if (any(missing_y)) {
    message(paste0("removing ", sum(missing_y), " observations due to missingness"))
    Y <- Y[!missing_y]
    X_scale <- X_scale[!missing_y, , drop = FALSE]
  }
  
  if (is.null(mod)) {
    message("modeling...")
    mod <- hierNet.path(y = Y,
                        x = X_scale,
                        lamlist = seq(from = 10, to = 200, by = 10),
                        strong = TRUE,
                        trace = 1)
    save(mod, file = paste0(save_path, "mod_", name, ".RData"))
  }
  
  # The CV grid is always the fitted model's own lambda path.
  mod.cv <- hierNet.cv.yhat(mod,
                            x = X_scale,
                            y = Y,
                            lamlist = mod$lamlist,
                            inpar = TRUE,
                            trace = 1)
  
  save(mod.cv, file = paste0(save_path, "modcv_yhats_", name, ".RData"))
  invisible(mod.cv)
}

# Predictors: aggregated document scores, one row per pundit. The first column
# is excluded from the predictor set below.
pundits_agg0 <- readr::read_csv("../data/pundits_parrot_aggregated_imp0.csv")

# Outcomes: the `ideal.point` and `tweetscore` columns are used below.
ideal.points_d1 <- readr::read_csv("../data/ideal_points_masked.csv")

# Rows with a missing tweetscore. The logical mask is used for subsetting
# because negative indexing with an empty which() result would drop every row.
which.na.ts <- which(is.na(ideal.points_d1$tweetscore))
has.ts <- !is.na(ideal.points_d1$tweetscore)

# load() is expected to restore a fitted path object named `mod`.
load("../output/mod_allvars.RData")

message("running full cross validated model...")
# lamlist is passed as a non-NULL placeholder; the CV grid comes from mod$lamlist.
run_predideo(input = pundits_agg0,
             mod = mod,
             lamlist = "include",
             outcome = ideal.points_d1$ideal.point,
             includevars = names(pundits_agg0)[-1],
             name = "allvars")

# Same procedure with tweetscore as the outcome, on rows where it is observed.
load("../output/mod_allvars_tweetscore.RData")
run_predideo(input = pundits_agg0[has.ts, ],
             mod = mod,
             lamlist = "include",
             outcome = ideal.points_d1$tweetscore[has.ts],
             includevars = names(pundits_agg0)[-1],
             name = "allvars_tweetscore")
